{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generative Adversarial Networks\n",
    "\n",
    "In this notebook, we play with the GAN described in the [course](https://mlelarge.github.io/dataflowr/Slides/GAN/index.html) on a double moon dataset.\n",
    "\n",
    "Then we implement a Conditional GAN and an InfoGAN."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# all of these libraries are used for plotting\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Plot the dataset\n",
    "def plot_data(ax, X, Y, color = 'bone'):\n",
    "    plt.axis('off')\n",
    "    ax.scatter(X[:, 0], X[:, 1], s=1, c=Y, cmap=color)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from sklearn.datasets import make_moons\n",
    "\n",
    "# Fixed seed so that Restart & Run All regenerates the same point cloud.\n",
    "X, y = make_moons(n_samples=2000, noise=0.05, random_state=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the dataset in a single colour: every point gets the same label 1.\n",
    "n_samples = len(X)\n",
    "Y = np.ones(n_samples)\n",
    "\n",
    "fig, ax = plt.subplots(facecolor='#4B6EA9')\n",
    "plot_data(ax, X, Y)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "use_gpu = torch.cuda.is_available()\n",
    "def gpu(tensor, gpu=use_gpu):\n",
    "    if gpu:\n",
    "        return tensor.cuda()\n",
    "    else:\n",
    "        return tensor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# A simple GAN\n",
    "\n",
    "We start with the simple GAN described in the course."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "z_dim = 32\n",
    "hidden_dim = 128\n",
    "\n",
    "# Generator: latent vector of size z_dim -> 2D point.\n",
    "net_G = nn.Sequential(\n",
    "    nn.Linear(z_dim, hidden_dim),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(hidden_dim, 2),\n",
    ")\n",
    "\n",
    "# Discriminator: 2D point -> a single score squashed into (0, 1) by the sigmoid.\n",
    "net_D = nn.Sequential(\n",
    "    nn.Linear(2, hidden_dim),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(hidden_dim, 1),\n",
    "    nn.Sigmoid(),\n",
    ")\n",
    "\n",
    "net_G, net_D = gpu(net_G), gpu(net_D)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Training loop as described in the course, keeping the losses for the discriminator and the generator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "batch_size = 50\n",
    "lr = 1e-4\n",
    "nb_epochs = 1000\n",
    "\n",
    "optimizer_G = torch.optim.Adam(net_G.parameters(), lr=lr)\n",
    "optimizer_D = torch.optim.Adam(net_D.parameters(), lr=lr)\n",
    "\n",
    "# Per-epoch sums of the batch losses, stored as Python floats so that they can\n",
    "# be plotted directly and do not keep any autograd history alive.\n",
    "loss_D_epoch = []\n",
    "loss_G_epoch = []\n",
    "for e in range(nb_epochs):\n",
    "    np.random.shuffle(X)\n",
    "    real_samples = torch.from_numpy(X).type(torch.FloatTensor)\n",
    "    loss_G = 0.\n",
    "    loss_D = 0.\n",
    "    for real_batch in real_samples.split(batch_size):\n",
    "        # split() may return a shorter last chunk; draw as many fake samples\n",
    "        # as there are real ones so the two score tensors have the same shape.\n",
    "        n_batch = real_batch.size(0)\n",
    "\n",
    "        # improving D\n",
    "        z = gpu(torch.empty(n_batch, z_dim).normal_())\n",
    "        # detach(): the discriminator step must not backpropagate into net_G\n",
    "        fake_batch = net_G(z).detach()\n",
    "        D_scores_on_real = net_D(gpu(real_batch))\n",
    "        D_scores_on_fake = net_D(fake_batch)\n",
    "\n",
    "        loss = -torch.mean(torch.log(1 - D_scores_on_fake) + torch.log(D_scores_on_real))\n",
    "        optimizer_D.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer_D.step()\n",
    "        loss_D += loss.item()\n",
    "\n",
    "        # improving G\n",
    "        z = gpu(torch.empty(n_batch, z_dim).normal_())\n",
    "        fake_batch = net_G(z)\n",
    "        D_scores_on_fake = net_D(fake_batch)\n",
    "\n",
    "        loss = -torch.mean(torch.log(D_scores_on_fake))\n",
    "        optimizer_G.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer_G.step()\n",
    "        loss_G += loss.item()\n",
    "\n",
    "    loss_D_epoch.append(loss_D)\n",
    "    loss_G_epoch.append(loss_G)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# float() accepts plain numbers as well as 0-d tensors (even ones that require\n",
    "# grad), so this works whatever the training loop stored in the lists.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot([float(l) for l in loss_D_epoch], label='discriminator')\n",
    "ax.plot([float(l) for l in loss_G_epoch], label='generator')\n",
    "ax.set_xlabel('epoch')\n",
    "ax.set_ylabel('summed batch loss')\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Sample as many fake points as there are real ones. no_grad(): the generator\n",
    "# is only run forward here, so no autograd graph needs to be recorded.\n",
    "z = gpu(torch.empty(n_samples, z_dim).normal_())\n",
    "with torch.no_grad():\n",
    "    fake_samples = net_G(z)\n",
    "fake_data = fake_samples.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack real points (label 1) on top of generated points (label 0).\n",
    "all_data = np.vstack((X, fake_data))\n",
    "Y2 = np.hstack((np.ones(n_samples), np.zeros(n_samples)))\n",
    "\n",
    "fig, ax = plt.subplots(facecolor='#4B6EA9')\n",
    "plot_data(ax, all_data, Y2)\n",
    "plt.show();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It looks like the GAN is oscillating. Try again with lr=1e-3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can generate more points:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ten generated points for every real one.\n",
    "n_fake = 10 * n_samples\n",
    "\n",
    "z = gpu(torch.empty(n_fake, z_dim).normal_())\n",
    "# no_grad(): forward pass only, no autograd graph needs to be recorded.\n",
    "with torch.no_grad():\n",
    "    fake_samples = net_G(z)\n",
    "fake_data = fake_samples.cpu().numpy()\n",
    "\n",
    "fig, ax = plt.subplots(1, 1, facecolor='#4B6EA9')\n",
    "all_data = np.concatenate((X, fake_data), axis=0)\n",
    "# Real points get label 1, generated points label 0.\n",
    "Y2 = np.concatenate((np.ones(n_samples), np.zeros(n_fake)))\n",
    "plot_data(ax, all_data, Y2)\n",
    "plt.show();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Conditional GAN\n",
    "\n",
    "We are now implementing a [conditional GAN](https://arxiv.org/abs/1411.1784).\n",
    "\n",
    "We start by separating the two half moons in two clusters as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regenerate the dataset, this time keeping the labels returned by make_moons.\n",
    "# Fixed seed so that Restart & Run All regenerates the same point cloud.\n",
    "X, Y = make_moons(n_samples=2000, noise=0.05, random_state=0)\n",
    "n_samples = X.shape[0]\n",
    "fig, ax = plt.subplots(1, 1, facecolor='#4B6EA9')\n",
    "plot_data(ax, X, Y)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The task is now, given a white or black label, to generate points in the corresponding cluster.\n",
    "\n",
    "Both the generator and the discriminator additionally take a one-hot encoding of the label as input. The generator now generates fake points corresponding to the input label. The discriminator, given a (sample, label) pair, should detect whether the pair is fake or real."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "z_dim = 32\n",
    "hidden_dim = 128\n",
    "label_dim = 2\n",
    "\n",
    "\n",
    "class generator(nn.Module):\n",
    "    \"\"\"Conditional generator: maps (noise, one-hot label) to a 2D point.\"\"\"\n",
    "\n",
    "    def __init__(self, z_dim=z_dim, label_dim=label_dim, hidden_dim=hidden_dim):\n",
    "        super().__init__()\n",
    "        layers = [\n",
    "            nn.Linear(z_dim + label_dim, hidden_dim),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(hidden_dim, 2),\n",
    "        ]\n",
    "        self.net = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, input, label_onehot):\n",
    "        # Condition on the label by appending it to the noise along dim 1.\n",
    "        conditioned = torch.cat((input, label_onehot), dim=1)\n",
    "        return self.net(conditioned)\n",
    "\n",
    "\n",
    "class discriminator(nn.Module):\n",
    "    \"\"\"Conditional discriminator: scores a (2D point, one-hot label) pair.\n",
    "\n",
    "    The output goes through a sigmoid. `z_dim` is accepted but not used.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, z_dim=z_dim, label_dim=label_dim, hidden_dim=hidden_dim):\n",
    "        super().__init__()\n",
    "        layers = [\n",
    "            nn.Linear(2 + label_dim, hidden_dim),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(hidden_dim, 1),\n",
    "            nn.Sigmoid(),\n",
    "        ]\n",
    "        self.net = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, input, label_onehot):\n",
    "        # The label is appended to the sample along dim 1 before scoring.\n",
    "        pair = torch.cat((input, label_onehot), dim=1)\n",
    "        return self.net(pair)\n",
    "\n",
    "\n",
    "net_CG = gpu(generator())\n",
    "net_CD = gpu(discriminator())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test the network and in particular the type of the inputs...\n",
    "\n",
    "Hint:  https://discuss.pytorch.org/t/convert-int-into-one-hot-format/507/4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 1\n",
    "z = gpu(torch.empty(batch_size,z_dim).normal_())\n",
    "# TODO (exercise): define `label_onehot` here, the one-hot encoding of one\n",
    "# label per sample. net_CG concatenates it with `z` along dim 1, so it needs\n",
    "# one row per sample and `label_dim` columns, on the same device as `z`.\n",
    "# Presumably a float tensor is expected (see the hint above) -- confirm.\n",
    "# Until it is defined, the next line raises a NameError.\n",
    "fake_batch = net_CG(z, label_onehot)\n",
    "print(fake_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score of the conditional discriminator on the (fake sample, label) pair.\n",
    "net_CD(input=fake_batch, label_onehot=label_onehot)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You need to code the training loop:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 50\n",
    "lr = 1e-3\n",
    "nb_epochs = 1000\n",
    "\n",
    "optimizer_CG = torch.optim.Adam(net_CG.parameters(), lr=lr)\n",
    "optimizer_CD = torch.optim.Adam(net_CD.parameters(), lr=lr)\n",
    "\n",
    "# Per-epoch sums of the batch losses, stored as Python floats so that they can\n",
    "# be plotted directly and do not keep any autograd history alive.\n",
    "loss_D_epoch = []\n",
    "loss_G_epoch = []\n",
    "for e in range(nb_epochs):\n",
    "    # Shuffle samples and labels with the same permutation. Indexing makes\n",
    "    # shuffled copies, so X and Y themselves are left untouched.\n",
    "    rperm = np.random.permutation(X.shape[0])\n",
    "    real_samples = torch.from_numpy(X[rperm]).type(torch.FloatTensor)\n",
    "    real_labels = torch.from_numpy(Y[rperm]).type(torch.LongTensor)\n",
    "    loss_G = 0.\n",
    "    loss_D = 0.\n",
    "    for real_batch, real_batch_label in zip(real_samples.split(batch_size), real_labels.split(batch_size)):\n",
    "        # split() may return a shorter last chunk; match the noise to it.\n",
    "        n_batch = real_batch.size(0)\n",
    "\n",
    "        # improving D\n",
    "        z = gpu(torch.empty(n_batch, z_dim).normal_())\n",
    "        #\n",
    "        # your code here (it must define `loss`, the discriminator loss)\n",
    "        #\n",
    "        #\n",
    "        #\n",
    "        #\n",
    "        optimizer_CD.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer_CD.step()\n",
    "        loss_D += loss.item()\n",
    "\n",
    "        # improving G\n",
    "        z = gpu(torch.empty(n_batch, z_dim).normal_())\n",
    "        #\n",
    "        # your code here (it must define `loss`, the generator loss)\n",
    "        #\n",
    "        #\n",
    "        #\n",
    "        optimizer_CG.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer_CG.step()\n",
    "        loss_G += loss.item()\n",
    "\n",
    "    loss_D_epoch.append(loss_D)\n",
    "    loss_G_epoch.append(loss_G)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# float() accepts plain numbers as well as 0-d tensors (even ones that require\n",
    "# grad), so this works whatever the training loop stored in the lists.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot([float(l) for l in loss_D_epoch], label='discriminator')\n",
    "ax.plot([float(l) for l in loss_G_epoch], label='generator')\n",
    "ax.set_xlabel('epoch')\n",
    "ax.set_ylabel('summed batch loss')\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "z = gpu(torch.empty(n_samples, z_dim).normal_())\n",
    "# One random class index in [0, label_dim) per sample. `label` stays on the\n",
    "# CPU because the plotting cell below converts it to numpy.\n",
    "label = torch.randint(label_dim, (n_samples, 1))\n",
    "# One-hot encoding: write a 1 in the column given by each row's label.\n",
    "label_onehot = torch.zeros(n_samples, label_dim).scatter_(1, label, 1)\n",
    "label_onehot = gpu(label_onehot)\n",
    "# no_grad(): forward pass only, no autograd graph needs to be recorded.\n",
    "with torch.no_grad():\n",
    "    fake_samples = net_CG(z, label_onehot)\n",
    "fake_data = fake_samples.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generated points coloured by the label they were conditioned on, with the\n",
    "# real dataset drawn on top using the 'spring' colormap.\n",
    "fig, ax = plt.subplots(facecolor='#4B6EA9')\n",
    "plot_data(ax, fake_data, label.numpy().squeeze())\n",
    "plot_data(ax, X, Y, color='spring')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Info GAN\n",
    "\n",
    "Implement the [algorithm](https://arxiv.org/abs/1606.03657).\n",
    "\n",
    "This time, you do not have access to the labels, but you know there are two classes. The idea is to provide a random label to the generator, as in the conditional GAN. In contrast to the conditional GAN, however, the discriminator cannot take the label as input (since the labels are not provided to us). Instead, the discriminator predicts a label, and this predictor can be trained on fake samples!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "z_dim = 32\n",
    "hidden_dim = 128\n",
    "label_dim = 2  # careful: est_label must be changed if there are more than two classes\n",
    "\n",
    "\n",
    "class Igenerator(nn.Module):\n",
    "    \"\"\"InfoGAN generator: maps (noise, one-hot label) to a 2D point.\"\"\"\n",
    "\n",
    "    def __init__(self, z_dim=z_dim, label_dim=label_dim, hidden_dim=hidden_dim):\n",
    "        super(Igenerator, self).__init__()\n",
    "        self.net = nn.Sequential(nn.Linear(z_dim + label_dim, hidden_dim),\n",
    "                                 nn.ReLU(),\n",
    "                                 nn.Linear(hidden_dim, 2))\n",
    "\n",
    "    def forward(self, input, label_onehot):\n",
    "        # The label is appended to the noise along dim 1.\n",
    "        x = torch.cat([input, label_onehot], 1)\n",
    "        return self.net(x)\n",
    "\n",
    "\n",
    "class Idiscriminator(nn.Module):\n",
    "    \"\"\"InfoGAN discriminator with two heads on a shared hidden layer.\n",
    "\n",
    "    forward() returns `(output, est_label)`: `output` comes from fc2 and\n",
    "    `est_label` from fc3, each a single sigmoid unit per sample. `z_dim` and\n",
    "    `label_dim` are accepted but not used.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, z_dim=z_dim, label_dim=label_dim, hidden_dim=hidden_dim):\n",
    "        super(Idiscriminator, self).__init__()\n",
    "        self.fc1 = nn.Linear(2, hidden_dim)\n",
    "        self.fc2 = nn.Linear(hidden_dim, 1)\n",
    "        self.fc3 = nn.Linear(hidden_dim, 1)\n",
    "\n",
    "    def forward(self, input):\n",
    "        x = F.relu(self.fc1(input))\n",
    "        # torch.sigmoid replaces the deprecated F.sigmoid\n",
    "        output = torch.sigmoid(self.fc2(x))\n",
    "        est_label = torch.sigmoid(self.fc3(x))\n",
    "        return output, est_label\n",
    "\n",
    "\n",
    "net_IG = gpu(Igenerator())\n",
    "net_ID = gpu(Idiscriminator())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "batch_size = 50\n",
    "lr = 1e-3\n",
    "nb_epochs = 1000\n",
    "loss_fn = nn.BCELoss()\n",
    "\n",
    "optimizer_IG = torch.optim.Adam(net_IG.parameters(), lr=lr)\n",
    "optimizer_ID = torch.optim.Adam(net_ID.parameters(), lr=lr)\n",
    "\n",
    "# Per-epoch sums of the batch losses, stored as Python floats so that they can\n",
    "# be plotted directly and do not keep any autograd history alive.\n",
    "loss_D_epoch = []\n",
    "loss_G_epoch = []\n",
    "for e in range(nb_epochs):\n",
    "    # Shuffle a copy of the samples. X itself must keep its row order: it is\n",
    "    # plotted against Y later on, and the labels Y are not used (nor shuffled)\n",
    "    # in this loop.\n",
    "    rperm = np.random.permutation(X.shape[0])\n",
    "    real_samples = torch.from_numpy(X[rperm]).type(torch.FloatTensor)\n",
    "    loss_G = 0.\n",
    "    loss_D = 0.\n",
    "    for real_batch in real_samples.split(batch_size):\n",
    "        # split() may return a shorter last chunk; match the noise to it.\n",
    "        n_batch = real_batch.size(0)\n",
    "\n",
    "        # improving D\n",
    "        z = gpu(torch.empty(n_batch, z_dim).normal_())\n",
    "        #\n",
    "        # your code here\n",
    "        #\n",
    "        #\n",
    "\n",
    "        lossD = -torch.mean(torch.log(1-D_scores_on_fake) + torch.log(D_scores_on_real))\n",
    "        #\n",
    "        # your code here (remember to add lossD.item() to loss_D)\n",
    "        #\n",
    "        #\n",
    "        #\n",
    "\n",
    "        # improving G\n",
    "        z = gpu(torch.empty(n_batch, z_dim).normal_())\n",
    "        #\n",
    "        # your code here\n",
    "        #\n",
    "        #\n",
    "\n",
    "        lossG = -torch.mean(torch.log(D_scores_on_fake))\n",
    "        #\n",
    "        # your code here\n",
    "        #\n",
    "        #\n",
    "        optimizer_IG.zero_grad()\n",
    "        lossG.backward()\n",
    "        optimizer_IG.step()\n",
    "        loss_G += lossG.item()\n",
    "\n",
    "    loss_D_epoch.append(loss_D)\n",
    "    loss_G_epoch.append(loss_G)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# float() accepts plain numbers as well as 0-d tensors (even ones that require\n",
    "# grad), so this works whatever the training loop stored in the lists.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot([float(l) for l in loss_D_epoch], label='discriminator')\n",
    "ax.plot([float(l) for l in loss_G_epoch], label='generator')\n",
    "ax.set_xlabel('epoch')\n",
    "ax.set_ylabel('summed batch loss')\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "z = gpu(torch.empty(n_samples, z_dim).normal_())\n",
    "# One random class index in [0, label_dim) per sample. `label` stays on the\n",
    "# CPU because the plotting cell below converts it to numpy.\n",
    "label = torch.randint(label_dim, (n_samples, 1))\n",
    "# One-hot encoding: write a 1 in the column given by each row's label.\n",
    "label_onehot = torch.zeros(n_samples, label_dim).scatter_(1, label, 1)\n",
    "label_onehot = gpu(label_onehot)\n",
    "# no_grad(): forward pass only, no autograd graph needs to be recorded.\n",
    "with torch.no_grad():\n",
    "    fake_samples = net_IG(z, label_onehot)\n",
    "fake_data = fake_samples.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generated points coloured by the random label fed to the generator, with\n",
    "# the real dataset drawn on top using the 'spring' colormap.\n",
    "fig, ax = plt.subplots(facecolor='#4B6EA9')\n",
    "plot_data(ax, fake_data, label.numpy().squeeze())\n",
    "plot_data(ax, X, Y, color='spring')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pytorch]",
   "language": "python",
   "name": "conda-env-pytorch-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
